import abc
import functools

import jax
import jax.numpy as np


class ChunkedDist(abc.ABC):
    """
    Base class for chunked distributions.

    The joint distribution p(zₚ, {zᵢ}ᴺ) is assumed to factorize as
    p(zₚ)∏ᵢp(zᵢ|zₚ), where zₚ is a parent variable and {zᵢ}ᴺ is a collection
    of N children variables; each zᵢ may itself be a multivariate random
    variable. A chunk is the pair (zₚ, zᵢ).

    A ChunkedDist can sample and evaluate either the full z = (zₚ, {zᵢ}ᴺ)
    or a single chunk (zₚ, zᵢ).

    ChunkedDist also works when zₚ is absent, i.e. when the distribution
    has only local variables zᵢ.

    Args:
        N_chunk: number of chunks N.
    """

    def __init__(self, N_chunk):
        self.N_chunk = N_chunk

    @abc.abstractmethod
    def sample(self, rng_key, params=None, chunk=None):
        """
        Draw a full sample (``chunk is None``) or a single sample chunk.

        Some distributions (e.g. targets) have no parameters; some (e.g.
        posteriors) may not implement sampling. Implementations must keep
        random numbers consistent: for the same ``rng_key``, the sample of
        chunk ``i`` must agree with the corresponding part of the full
        sample.
        """

    @abc.abstractmethod
    def log_prob(self, z, params=None, chunk=None):
        """
        Evaluate the log density of a full sample (``chunk is None``) or of
        a single sample chunk.

        Some distributions (e.g. targets) have no parameters.
        ``test_logp_consistency`` expects the mean over all chunks of the
        per-chunk values to equal the full log density.
        """

    def _log_prob(self, z, params=None, chunk=None):
        """
        An internal function to help evaluate log prob.
        """
        raise NotImplementedError

    def _log_prob_eval(self, z, params=None, chunk=None):
        """
        An internal function useful for debugging and monitoring.
        """
        raise NotImplementedError

    def test_logp_consistency(
            self, params, z, z_chunks, eps=1e-5, return_diff=False):
        """
        Check that chunk-level log densities agree with the full one.

        The mean over chunks of ``log_prob(z_chunks[i], params, i)`` must
        match ``log_prob(z, params, None)`` up to relative tolerance ``eps``.

        Args:
            params: parameters passed to ``log_prob``.
            z: a full sample.
            z_chunks: per-chunk samples, batched along leading axis 0
                (vmapped over ``N_chunk`` chunks).
            eps: relative tolerance.
            return_diff: if True, return the relative difference.

        Returns:
            ``|log_p - mean(log_p_chunks)| / |log_p|`` if ``return_diff`` is
            True (the absolute difference when ``log_p`` is exactly zero),
            otherwise None.

        Raises:
            AssertionError: if the difference exceeds the tolerance.
        """
        log_p_chunks = jax.vmap(self.log_prob, in_axes=(0, None, 0))(
            z_chunks, params, np.arange(self.N_chunk))
        log_p = self.log_prob(z, params, None)
        diff = np.abs(log_p - np.mean(log_p_chunks))
        scale = np.abs(log_p)
        # `<=` (not `<`) so an exact match passes even when log_p == 0.
        assert diff <= eps * scale, (
            f"log_prob inconsistency: full={log_p}, "
            f"mean over chunks={np.mean(log_p_chunks)}")
        if return_diff:
            # Avoid 0/0 = nan when the full log density is exactly zero.
            return diff / scale if scale > 0 else diff

    def test_cross_consistency(
            self, params, other_dist, other_params, eps=1e-5,
            return_diff=False, rng_key=None):
        """
        Draw a sample from another distribution, then check that this
        distribution gives consistent full and chunk-level log_probs.

        Args:
            params: parameters of this distribution.
            other_dist: distribution to sample from; must have the same
                ``N_chunk``.
            other_params: parameters of ``other_dist``.
            eps: relative tolerance.
            return_diff: if True, return the relative difference.
            rng_key: PRNG key for sampling; defaults to ``PRNGKey(0)``.
        """
        assert self.N_chunk == other_dist.N_chunk, "N_chunk mismatch"
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        z = other_dist.sample(rng_key, other_params, None)
        # Same key for every chunk: chunk samples must match the full sample.
        z_chunks = jax.vmap(other_dist.sample, in_axes=(None, None, 0))(
            rng_key, other_params, np.arange(self.N_chunk))
        return self.test_logp_consistency(
            params, z, z_chunks, eps=eps, return_diff=return_diff)

    def test_self_consistency(
            self, params, eps=1e-5, return_diff=False, rng_key=None):
        """
        Run ``test_cross_consistency`` using this distribution as the
        sampler.
        """
        return self.test_cross_consistency(
            params, self, params, eps=eps, return_diff=return_diff,
            rng_key=rng_key)

class BranchDist(ChunkedDist):
    """
    Chunked distribution with a parent variable and per-chunk children.

    A branched distribution is natural when there is a parent random
    variable zₚ (here θ) and at least one child zᵢ (here wᵢ). The log
    density is log p(θ) + Σᵢ log p(wᵢ|θ). Subclasses provide the four
    primitives ``sample_parent``, ``sample_child``, ``eval_parent`` and
    ``eval_child``.

    Samples are pairs ``(θ, w)``. For a full sample, ``w`` holds the
    ``N_chunk`` children batched along a leading axis (built with
    ``jax.vmap``). For a single chunk, it is that one child.
    """

    def __init__(self, N_chunk):
        super().__init__(N_chunk)

    def sample(self, rng_key, params=None, chunk=None, **kwargs):
        """
        Draw a full sample ``(θ, w)`` or a single chunk ``(θ, wᵢ)``.

        θ is drawn from a split of ``rng_key`` that does not depend on
        ``chunk``, and child ``i`` always uses ``fold_in(key, i)``. A chunk
        sample therefore agrees with the matching slice of the full sample
        for the same ``rng_key``.

        Extra ``kwargs`` are forwarded to ``sample_parent`` and
        ``sample_child``.
        """
        rng_key, parent_key = jax.random.split(rng_key)
        θ = self.sample_parent(parent_key, params, **kwargs)

        # functools.partial replaces ``jax.partial``, which JAX removed.
        sample_child = functools.partial(self.sample_child, **kwargs)

        if chunk is not None:
            child_key = jax.random.fold_in(rng_key, chunk)
            return θ, sample_child(child_key, params, θ, chunk)

        chunks = np.arange(self.N_chunk)
        child_keys = jax.vmap(lambda i: jax.random.fold_in(rng_key, i))(chunks)
        w = jax.vmap(sample_child, in_axes=(0, None, None, 0))(
            child_keys, params, θ, chunks)
        return θ, w

    def _log_prob(self, z, params=None, chunk=None, **kwargs):
        """
        Return the parent and child log-density terms separately.

        Args:
            z: ``(θ, w)`` for a full sample, ``(θ, wᵢ)`` for a single chunk.
            params: parameters forwarded to ``eval_parent``/``eval_child``.
            chunk: chunk index, or None for a full sample.

        Returns:
            ``(lθ, lwᵢ)`` for a chunk, or ``(lθ, lw)`` for a full sample,
            with ``lw`` holding one value per chunk.
        """
        θ, w = z[0], z[1]
        lθ = self.eval_parent(θ, params, **kwargs)

        eval_child = functools.partial(self.eval_child, **kwargs)

        if chunk is not None:
            return lθ, eval_child(θ, params, w, chunk)

        lw = jax.vmap(eval_child, in_axes=(None, None, 0, 0))(
            θ, params, w, np.arange(self.N_chunk))
        return lθ, lw

    def log_prob(self, z, params=None, chunk=None, **kwargs):
        """
        Return the log density of a full sample or a chunk-level estimate.

        For a chunk, the child term is scaled by ``N_chunk``. The mean over
        all chunks then equals the full log density ``lθ + Σᵢ lwᵢ``.
        """
        lθ, lw = self._log_prob(z, params, chunk, **kwargs)
        if chunk is not None:
            return lθ + self.N_chunk * lw
        return lθ + np.sum(lw)

    @abc.abstractmethod
    def sample_parent(self, rng_key, params, **kwargs):
        """Draw the parent variable θ."""

    @abc.abstractmethod
    def sample_child(self, rng_key, params, θ, chunk, **kwargs):
        """Draw child ``chunk`` given the parent θ."""

    @abc.abstractmethod
    def eval_parent(self, θ, params, **kwargs):
        """Return the log density of the parent θ."""

    @abc.abstractmethod
    def eval_child(self, θ, params, wi, chunk, **kwargs):
        """Return the log density of child ``wi`` (chunk ``chunk``) given θ."""



class SimpleBranchDist(BranchDist):
    """
    BranchDist built from a parent distribution object and a child
    distribution object.

    Subclasses implement ``parent_dist`` and ``child_dist``. The BranchDist
    primitives (sample/eval of parent and child) are delegated to the
    ``sample`` and ``log_prob`` methods of the returned objects.
    """

    @abc.abstractmethod
    def parent_dist(self, params, **kwargs):
        """
        Return the parent distribution.

        The returned object is expected to support:
            parent_dist(params).sample(rng_key)
            parent_dist(params).log_prob(θ)

        Args:
            params: params that parameterize the distribution.
        """

    @abc.abstractmethod
    def child_dist(self, θ, params, chunk, **kwargs):
        """
        Return the child distribution for chunk ``chunk`` given θ.

        The returned object is expected to support:
            child_dist(θ, params, chunk).sample(rng_key)
            child_dist(θ, params, chunk).log_prob(wi)

        Args:
            θ: value of the parent variable.
            params: params that parameterize the distribution.
            chunk: index of the child.
        """

    def sample_parent(self, rng_key, params, **kwargs):
        dist = self.parent_dist(params, **kwargs)
        return dist.sample(rng_key)

    def sample_child(self, rng_key, params, θ, chunk, **kwargs):
        dist = self.child_dist(θ, params, chunk, **kwargs)
        return dist.sample(rng_key)

    def eval_parent(self, θ, params, **kwargs):
        dist = self.parent_dist(params, **kwargs)
        return dist.log_prob(θ)

    def eval_child(self, θ, params, wi, chunk, **kwargs):
        dist = self.child_dist(θ, params, chunk, **kwargs)
        return dist.log_prob(wi)


class SimpleBranchDistWithSampleEval(SimpleBranchDist):
    """
    SimpleBranchDist whose parent and child distribution objects also expose
    ``sample_and_log_prob(rng_key)``. Drawing a sample and evaluating its log
    density can then happen in one call.
    """

    def sample_and_log_prob(self, rng_key, params=None, chunk=None, **kwargs):
        """
        Draw a sample together with its log density.

        Keys are derived exactly as in ``BranchDist.sample``: θ comes from a
        split of ``rng_key``, and child ``i`` uses ``fold_in(key, i)``.

        Returns:
            ``((θ, wᵢ), lθ + N_chunk * lwᵢ)`` for a single chunk, or
            ``((θ, w), lθ + Σᵢ lwᵢ)`` for a full sample (``chunk is None``).
        """
        rng_key, parent_key = jax.random.split(rng_key)
        θ, lθ = self.sample_and_eval_parent(parent_key, params, **kwargs)

        # functools.partial replaces ``jax.partial``, which JAX removed.
        sample_and_eval_child = functools.partial(
            self.sample_and_eval_child, **kwargs)

        if chunk is not None:
            child_key = jax.random.fold_in(rng_key, chunk)
            w, lw = sample_and_eval_child(child_key, params, θ, chunk)
            return (θ, w), lθ + self.N_chunk * lw

        chunks = np.arange(self.N_chunk)
        child_keys = jax.vmap(lambda i: jax.random.fold_in(rng_key, i))(chunks)
        w, lw = jax.vmap(sample_and_eval_child, in_axes=(0, None, None, 0))(
            child_keys, params, θ, chunks)
        return (θ, w), lθ + np.sum(lw)

    def sample_and_eval_parent(self, rng_key, params, **kwargs):
        """Return ``(θ, log p(θ))`` from the parent distribution."""
        return self.parent_dist(params, **kwargs).sample_and_log_prob(rng_key)

    def sample_and_eval_child(self, rng_key, params, θ, chunk, **kwargs):
        """Return ``(wᵢ, log p(wᵢ|θ))`` from the child distribution."""
        return self.child_dist(
            θ, params, chunk, **kwargs).sample_and_log_prob(rng_key)
